{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RNN的output和state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.0.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "tf.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 伪造输入张量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([2, 4, 16])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seq = tf.constant([1,2,3,4,5,6,7,8], shape=(2,4))\n",
    "embedding = tf.keras.layers.Embedding(10,16)\n",
    "\n",
    "inp = embedding(seq)\n",
    "\n",
    "inp.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## return_sequences=False, return_state=False\n",
    "\n",
    "因为`return_sequences=False`和`return_state=False`，所以`LSTM`返回的只有最后一个timestep的输出，即`output`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2, 16)\n"
     ]
    }
   ],
   "source": [
    "enc = tf.keras.layers.LSTM(16)\n",
    "\n",
    "output = enc(inp)\n",
    "\n",
    "print(output.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## return_sequences=True, return_state=False\n",
    "\n",
    "因为`return_sequences=True`，所以会输出每一个timestep的`output`。因此，这里的`output`张量比上面增加了一个维度，`shape:(batch_size, time_steps, units)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2, 4, 16)\n"
     ]
    }
   ],
   "source": [
    "enc2 = tf.keras.layers.LSTM(16, return_sequences=True, return_state=False)\n",
    "\n",
    "output = enc2(inp)\n",
    "\n",
    "print(output.shape)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## return_sequences=False, return_state=True\n",
    "\n",
    "因为`return_state=True`，所以LSTM会返回最后一个timestep的`state`，实际上`RNN`的`state`分为两个`state_h`和`state_c`。所以：\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2, 16)\n",
      "(2, 16)\n",
      "(2, 16)\n"
     ]
    }
   ],
   "source": [
    "enc3 = tf.keras.layers.LSTM(16, return_sequences=False, return_state=True)\n",
    "\n",
    "output, state_h, state_c = enc3(inp)\n",
    "\n",
    "print(output.shape)\n",
    "print(state_h.shape)\n",
    "print(state_c.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## return_sequences=True, return_state=True\n",
    "\n",
    "因为`return_sequences=True`并且`return_state=True`，所以LSTM会返回每个timestep的输出组成的`output`，以及最后一个timestep的`state`（分为`state_h`和`state_c`两个)。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2, 4, 16)\n",
      "(2, 16)\n",
      "(2, 16)\n"
     ]
    }
   ],
   "source": [
    "enc4 = tf.keras.layers.LSTM(16, return_sequences=True, return_state=True)\n",
    "\n",
    "output, state_h, state_c = enc4(inp)\n",
    "\n",
    "print(output.shape)\n",
    "print(state_h.shape)\n",
    "print(state_c.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过以上小实验我们可以得出以下结论：\n",
    "\n",
    "* RNN的输出，可以通过`return_sequences`和`return_state`两个参数控制\n",
    "* `return_sequences=True`说明返回每一个timestep的输出，组成最终的`output`，`shape: (batch_size, time_steps, units)`\n",
    "* `return_sequences=False`说明返回最后一个timestep的输出，组成最终的`output`，`shape: (batch_size, units)`\n",
    "* `return_state=True`说明返回最后一个timestep的状态，组成最终的`state`\n",
    "* `return_state=False`说明不返回最后一个timestep的状态，LSTM的输出仅仅有`output`\n",
    "\n",
    "\n",
    "还有一个trick可以帮助理解：\n",
    "\n",
    "* `return_sequences`用了`sequence`的复数表示`sequences`，说明返回的是一个序列的输出，组成最后的`output`\n",
    "* `return_state`没用`state`的复数表示`states`，说明仅仅返回最后一个timestep的状态，组成最后的`state`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
